feat(onnx): add input_check flag to LightningModule.to_onnx #21716

Open: nuemaan wants to merge 1 commit into Lightning-AI:master from nuemaan:add-onnx-input-check

nuemaan commented May 16, 2026

What does this PR do?

Adds an input_check: bool = False argument to LightningModule.to_onnx. When it is True, to_onnx loads the saved model once torch.onnx.export returns and runs onnx.checker.check_model on it, so the caller gets an onnx.checker.ValidationError at export time instead of finding out at deploy time that the exported protobuf is malformed.

Default behavior is unchanged: input_check=False.

Refs #7279.

Notes on scope

The original issue lists two things: the spec check (onnx.checker.check_model) and a PyTorch-vs-onnxruntime output comparison. I only did the first one here — the runtime comparison needs onnxruntime as a hard dep and a way to deal with multi-output / dict-output models, which feels like a separate PR. Happy to add it as a follow-up if you'd like.

Behavior

  • file_path is a str/Path → loaded with onnx.load, then checked.
  • file_path is BytesIO → loaded with onnx.load_model_from_string; the buffer's cursor position is restored so the caller sees the same state as without the flag.
  • file_path is None → raises ValueError; there is nothing to load.
  • dynamo=True → raises ValueError; that path returns an ONNXProgram rather than a standalone protobuf file we can hand to the checker.
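
For reviewers who want the four cases above in one place, here is a minimal sketch of the post-export check. It is not the code in this PR: check_exported_model, the order of the two guards, and the error messages are all hypothetical, and onnx is imported lazily only so the guards can run without it installed.

```python
from io import BytesIO
from pathlib import Path
from typing import Union


def check_exported_model(
    file_path: Union[str, Path, BytesIO, None],
    dynamo: bool = False,
) -> None:
    """Hypothetical helper mirroring the behavior listed above."""
    # Guard order and messages are assumptions, not taken from the PR.
    if dynamo:
        raise ValueError("input_check=True is not supported with dynamo=True")
    if file_path is None:
        raise ValueError("input_check=True requires a file_path to load from")

    import onnx  # lazy import: only needed once there is a model to load

    if isinstance(file_path, BytesIO):
        # Remember the cursor so the caller sees the buffer unchanged.
        position = file_path.tell()
        try:
            file_path.seek(0)
            model = onnx.load_model_from_string(file_path.read())
        finally:
            file_path.seek(position)
    else:
        model = onnx.load(str(file_path))

    # Raises onnx.checker.ValidationError if the protobuf is malformed.
    onnx.checker.check_model(model)
```

Under these assumptions, check_exported_model(None) and check_exported_model("model.onnx", dynamo=True) both raise ValueError before anything is loaded.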

Tests

Added to tests/tests_pytorch/models/test_onnx.py:

  • test_input_check_runs_onnx_checker: happy path for both file and BytesIO.
  • test_input_check_raises_without_file_path: ValueError when file_path=None.
  • test_input_check_detects_invalid_model: monkeypatches onnx.checker.check_model to fail; verifies the error propagates.
  • test_input_check_rejects_dynamo: ValueError when combined with dynamo=True.

All onnx tests pass locally on macOS (torch 2.12.0, onnx 1.21.0, onnxruntime 1.26.0):

13 passed, 5 skipped

Ruff lint + format pass.

Before submitting

  • Was this discussed/agreed via a GitHub issue? — yes, Add checks for model spec and matching output values in to_onnx() method #7279
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (added a short note in docs/source-pytorch/deploy/production_advanced.rst)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request? — none, new kwarg defaults to False.
  • Did you update the CHANGELOG?

PR review

Anyone in the community is welcome to review the PR.


📚 Documentation preview 📚: https://pytorch-lightning--21716.org.readthedocs.build/en/21716/

Optional onnx.checker.check_model pass after export, gated behind
input_check (default False) to keep the default path unchanged.
Validates the saved file via onnx.load (or
onnx.load_model_from_string for BytesIO) and surfaces ValidationError
to the caller.

Disallowed with dynamo=True since that path returns an ONNXProgram
and does not always produce a standalone protobuf file the checker
can load; also rejected when file_path is None.

Refs Lightning-AI#7279.
github-actions bot added the labels docs (Documentation related) and pl (Generic label for PyTorch Lightning package) on May 16, 2026